/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.iotdb.db.queryengine.plan.relational.sql.parser;

import com.google.common.collect.ImmutableSet;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.NoViableAltException;
import org.antlr.v4.runtime.Parser;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.RuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.Vocabulary;
import org.antlr.v4.runtime.atn.ATN;
import org.antlr.v4.runtime.atn.ATNState;
import org.antlr.v4.runtime.atn.NotSetTransition;
import org.antlr.v4.runtime.atn.PrecedencePredicateTransition;
import org.antlr.v4.runtime.atn.RuleStartState;
import org.antlr.v4.runtime.atn.RuleStopState;
import org.antlr.v4.runtime.atn.RuleTransition;
import org.antlr.v4.runtime.atn.Transition;
import org.antlr.v4.runtime.atn.WildcardTransition;
import org.antlr.v4.runtime.misc.IntervalSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.base.MoreObjects.firstNonNull;
import static org.antlr.v4.runtime.atn.ATNState.RULE_START;

public class ErrorHandler extends BaseErrorListener {

  private static final Logger LOGGER = LoggerFactory.getLogger(ErrorHandler.class);

  private final Map<Integer, String> specialRules;
  private final Map<Integer, String> specialTokens;
  private final Set<Integer> ignoredRules;

  private ErrorHandler(
      Map<Integer, String> specialRules,
      Map<Integer, String> specialTokens,
      Set<Integer> ignoredRules) {
    this.specialRules = new HashMap<>(specialRules);
    this.specialTokens = specialTokens;
    this.ignoredRules = new HashSet<>(ignoredRules);
  }

  @Override
  public void syntaxError(
      Recognizer<?, ?> recognizer,
      Object offendingSymbol,
      int line,
      int charPositionInLine,
      String message,
      RecognitionException e) {
    try {
      Parser parser = (Parser) recognizer;

      Result result = analyze(e, parser);

      // pick the candidate tokens associated largest token index processed (i.e., the path that
      // consumed the most input)
      String expected = result.getExpected().stream().sorted().collect(Collectors.joining(", "));

      message =
          String.format(
              "mismatched input '%s'. Expecting: %s",
              parser.getTokenStream().get(result.getErrorTokenIndex()).getText(), expected);
    } catch (Exception exception) {
      LOGGER.warn(
          "Unexpected failure when handling parsing error. This is likely a bug in the implementation",
          exception);
    }

    throw new ParsingException(message, e, line, charPositionInLine + 1);
  }

  private Result analyze(RecognitionException e, Parser parser) {
    ATN atn = parser.getATN();

    ATNState currentState;
    Token currentToken;
    RuleContext context;

    if (e != null) {
      currentState = atn.states.get(e.getOffendingState());
      currentToken = e.getOffendingToken();
      context = e.getCtx();

      if (e instanceof NoViableAltException) {
        currentToken = ((NoViableAltException) e).getStartToken();
      }
    } else {
      currentState = atn.states.get(parser.getState());
      currentToken = parser.getCurrentToken();
      context = parser.getContext();
    }

    Analyzer analyzer = new Analyzer(parser, specialRules, specialTokens, ignoredRules);
    return analyzer.process(currentState, currentToken.getTokenIndex(), context);
  }

  private static class ParsingState {
    public final ATNState state;
    public final int tokenIndex;
    public final boolean suppressed;
    public final Parser parser;

    public ParsingState(ATNState state, int tokenIndex, boolean suppressed, Parser parser) {
      this.state = state;
      this.tokenIndex = tokenIndex;
      this.suppressed = suppressed;
      this.parser = parser;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      ParsingState that = (ParsingState) o;
      return tokenIndex == that.tokenIndex && state.equals(that.state);
    }

    @Override
    public int hashCode() {
      return Objects.hash(state, tokenIndex);
    }

    @Override
    public String toString() {
      Token token = parser.getTokenStream().get(tokenIndex);

      String text = firstNonNull(token.getText(), "?");
      text = text.replace("\\", "\\\\");
      text = text.replace("\n", "\\n");
      text = text.replace("\r", "\\r");
      text = text.replace("\t", "\\t");

      return String.format(
          "%s%s:%s @ %s:<%s>:%s",
          suppressed ? "-" : "+",
          parser.getRuleNames()[state.ruleIndex],
          state.stateNumber,
          tokenIndex,
          parser.getVocabulary().getSymbolicName(token.getType()),
          text);
    }
  }

  private static class Analyzer {
    private final Parser parser;
    private final ATN atn;
    private final Vocabulary vocabulary;
    private final Map<Integer, String> specialRules;
    private final Map<Integer, String> specialTokens;
    private final Set<Integer> ignoredRules;
    private final TokenStream stream;

    private int furthestTokenIndex = -1;
    private final Set<String> candidates = new HashSet<>();

    private final Map<ParsingState, Set<Integer>> memo = new HashMap<>();

    public Analyzer(
        Parser parser,
        Map<Integer, String> specialRules,
        Map<Integer, String> specialTokens,
        Set<Integer> ignoredRules) {
      this.parser = parser;
      this.stream = parser.getTokenStream();
      this.atn = parser.getATN();
      this.vocabulary = parser.getVocabulary();
      this.specialRules = specialRules;
      this.specialTokens = specialTokens;
      this.ignoredRules = ignoredRules;
    }

    public Result process(ATNState currentState, int tokenIndex, RuleContext context) {
      RuleStartState startState = atn.ruleToStartState[currentState.ruleIndex];

      if (isReachable(currentState, startState)) {
        // We've been dropped inside a rule in a state that's reachable via epsilon transitions.
        // This is,
        // effectively, equivalent to starting at the beginning (or immediately outside) the rule.
        // In that case, backtrack to the beginning to be able to take advantage of logic that
        // remaps
        // some rules to well-known names for reporting purposes
        currentState = startState;
      }

      Set<Integer> endTokens =
          process(new ParsingState(currentState, tokenIndex, false, parser), 0);
      while (!endTokens.isEmpty() && context.invokingState != -1) {
        ATNState nextState =
            ((RuleTransition) atn.states.get(context.invokingState).transition(0)).followState;
        endTokens =
            endTokens.stream()
                .flatMap(
                    endToken ->
                        process(new ParsingState(nextState, endToken, false, parser), 0).stream())
                .collect(Collectors.toSet());
        context = context.parent;
      }

      return new Result(furthestTokenIndex, candidates);
    }

    private boolean isReachable(ATNState target, RuleStartState from) {
      Deque<ATNState> activeStates = new ArrayDeque<>();
      activeStates.add(from);

      while (!activeStates.isEmpty()) {
        ATNState current = activeStates.pop();

        if (current.stateNumber == target.stateNumber) {
          return true;
        }

        for (int i = 0; i < current.getNumberOfTransitions(); i++) {
          Transition transition = current.transition(i);

          if (transition.isEpsilon()) {
            activeStates.push(transition.target);
          }
        }
      }

      return false;
    }

    private Set<Integer> process(ParsingState start, int precedence) {
      Set<Integer> result = memo.get(start);
      if (result != null) {
        return result;
      }

      ImmutableSet.Builder<Integer> endTokens = ImmutableSet.builder();

      // Simulates the ATN by consuming input tokens and walking transitions.
      // The ATN can be in multiple states (similar to an NFA)
      Deque<ParsingState> activeStates = new ArrayDeque<>();
      activeStates.add(start);

      while (!activeStates.isEmpty()) {
        ParsingState current = activeStates.pop();

        ATNState state = current.state;
        int tokenIndex = current.tokenIndex;
        boolean suppressed = current.suppressed;

        while (stream.get(tokenIndex).getChannel() == Token.HIDDEN_CHANNEL) {
          // Ignore whitespace
          tokenIndex++;
        }
        int currentToken = stream.get(tokenIndex).getType();

        if (state.getStateType() == RULE_START) {
          int rule = state.ruleIndex;

          if (specialRules.containsKey(rule)) {
            if (!suppressed) {
              recordLabel(tokenIndex, specialRules.get(rule));
            }
            suppressed = true;
          } else if (ignoredRules.contains(rule)) {
            // TODO expand ignored rules like we expand special rules
            continue;
          }
        }

        if (state instanceof RuleStopState) {
          endTokens.add(tokenIndex);
          continue;
        }

        for (int i = 0; i < state.getNumberOfTransitions(); i++) {
          Transition transition = state.transition(i);

          if (transition instanceof RuleTransition) {
            RuleTransition ruleTransition = (RuleTransition) transition;
            for (int endToken :
                process(
                    new ParsingState(ruleTransition.target, tokenIndex, suppressed, parser),
                    ruleTransition.precedence)) {
              activeStates.push(
                  new ParsingState(
                      ruleTransition.followState,
                      endToken,
                      suppressed && endToken == currentToken,
                      parser));
            }
          } else if (transition instanceof PrecedencePredicateTransition) {
            if (precedence < ((PrecedencePredicateTransition) transition).precedence) {
              activeStates.push(
                  new ParsingState(transition.target, tokenIndex, suppressed, parser));
            }
          } else if (transition.isEpsilon()) {
            activeStates.push(new ParsingState(transition.target, tokenIndex, suppressed, parser));
          } else if (transition instanceof WildcardTransition) {
            throw new UnsupportedOperationException("not yet implemented: wildcard transition");
          } else {
            IntervalSet labels = transition.label();

            if (transition instanceof NotSetTransition) {
              labels =
                  labels.complement(IntervalSet.of(Token.MIN_USER_TOKEN_TYPE, atn.maxTokenType));
            }

            // Surprisingly, TokenStream (i.e. BufferedTokenStream) may not have loaded all the
            // tokens from the
            // underlying stream. TokenStream.get() does not force tokens to be buffered -- it just
            // returns what's
            // in the current buffer, or fail with an IndexOutOfBoundsError. Since Antlr decided the
            // error occurred
            // within the current set of buffered tokens, stop when we reach the end of the buffer.
            if (labels.contains(currentToken) && tokenIndex < stream.size() - 1) {
              activeStates.push(new ParsingState(transition.target, tokenIndex + 1, false, parser));
            } else {
              if (!suppressed) {
                recordLabel(tokenIndex, getTokenNames(labels));
              }
            }
          }
        }
      }

      result = endTokens.build();
      memo.put(start, result);
      return result;
    }

    private void recordLabel(int tokenIndex, String label) {
      recordLabel(tokenIndex, ImmutableSet.of(label));
    }

    private void recordLabel(int tokenIndex, Set<String> labels) {
      if (tokenIndex >= furthestTokenIndex) {
        if (tokenIndex > furthestTokenIndex) {
          candidates.clear();
          furthestTokenIndex = tokenIndex;
        }

        candidates.addAll(labels);
      }
    }

    private Set<String> getTokenNames(IntervalSet tokens) {
      Set<String> names = new HashSet<>();
      for (int i = 0; i < tokens.size(); i++) {
        int token = tokens.get(i);
        if (token == Token.EOF) {
          names.add("<EOF>");
        } else {
          names.add(specialTokens.getOrDefault(token, vocabulary.getDisplayName(token)));
        }
      }

      return names;
    }
  }

  public static Builder builder() {
    return new Builder();
  }

  public static class Builder {
    private final Map<Integer, String> specialRules = new HashMap<>();
    private final Map<Integer, String> specialTokens = new HashMap<>();
    private final Set<Integer> ignoredRules = new HashSet<>();

    public Builder specialRule(int ruleId, String name) {
      specialRules.put(ruleId, name);
      return this;
    }

    public Builder specialToken(int tokenId, String name) {
      specialTokens.put(tokenId, name);
      return this;
    }

    public Builder ignoredRule(int ruleId) {
      ignoredRules.add(ruleId);
      return this;
    }

    public ErrorHandler build() {
      return new ErrorHandler(specialRules, specialTokens, ignoredRules);
    }
  }

  private static class Result {
    private final int errorTokenIndex;
    private final Set<String> expected;

    public Result(int errorTokenIndex, Set<String> expected) {
      this.errorTokenIndex = errorTokenIndex;
      this.expected = expected;
    }

    public int getErrorTokenIndex() {
      return errorTokenIndex;
    }

    public Set<String> getExpected() {
      return expected;
    }
  }
}
